{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T01:42:25.413605Z",
     "start_time": "2019-05-22T01:42:23.893315Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"BERT finetuning runner.\"\"\"\n",
    "\n",
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import collections\n",
    "import csv\n",
    "import os\n",
    "import modeling\n",
    "import optimization\n",
    "import tokenization\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn import metrics\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"0,1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T01:43:31.325188Z",
     "start_time": "2019-05-22T01:43:31.318808Z"
    }
   },
   "outputs": [],
   "source": [
    "data_dir = '../data/atis'\n",
    "log_dir = '../logs/joint_model'\n",
    "\n",
    "bert_model_dir = '../checkpoints/cased_L-12_H-768_A-12'\n",
    "bert_config_file = os.path.join(bert_model_dir, 'bert_config.json')\n",
    "vocab_file = os.path.join(bert_model_dir, 'vocab.txt')\n",
    "init_checkpoint = os.path.join(bert_model_dir, 'bert_model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T01:43:20.215688Z",
     "start_time": "2019-05-22T01:43:20.208285Z"
    }
   },
   "outputs": [],
   "source": [
    "num_gpus = 4\n",
    "do_lower_case = False\n",
    "do_train = True\n",
    "do_eval = True\n",
    "do_predict = True\n",
    "\n",
    "max_seq_length = 50\n",
    "train_batch_size = 32\n",
    "eval_batch_size = 8\n",
    "predict_batch_size = 100\n",
    "learning_rate = 5e-5\n",
    "num_train_epochs = 1\n",
    "warmup_proportion = 0.1\n",
    "save_checkpoints_steps = 1000\n",
    "log_step_count_steps = 10\n",
    "save_summary_steps = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### InputExample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T01:28:07.067314Z",
     "start_time": "2019-05-22T01:28:07.059244Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "class InputExample(object):\n",
    "    \"\"\"A single training/test example for simple sequence classification.\"\"\"\n",
    "\n",
    "    def __init__(self, guid, text, tags, label):\n",
    "        \"\"\"Constructs a InputExample.\n",
    "\n",
    "        Args:\n",
    "          guid: Unique id for the example.\n",
    "          text: string. The untokenized text of the first sequence.\n",
    "          label: The label of the example. This should be\n",
    "            specified for train and dev examples, but not for test examples.\n",
    "        \"\"\"\n",
    "        self.guid = guid\n",
    "        self.text = text\n",
    "        self.tags = tags\n",
    "        self.label = label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T01:29:13.770194Z",
     "start_time": "2019-05-22T01:29:13.762420Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "class InputFeatures(object):\n",
    "    \"\"\"A single set of features of data.\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 input_ids,\n",
    "                 input_mask,\n",
    "                 segment_ids,\n",
    "                 tags_ids,\n",
    "                 label_id,\n",
    "                 is_real_example=True):\n",
    "        self.input_ids = input_ids\n",
    "        self.input_mask = input_mask\n",
    "        self.segment_ids = segment_ids\n",
    "        self.tags_ids = tags_ids\n",
    "        self.label_id = label_id\n",
    "        self.is_real_example = is_real_example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DataProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T01:41:31.627487Z",
     "start_time": "2019-05-22T01:41:31.600203Z"
    }
   },
   "outputs": [],
   "source": [
    "class DataProcessor(object):\n",
    "    def __init__(self, data_dir):\n",
    "        self.data_dir = data_dir\n",
    "        self.train_path = os.path.join(self.data_dir, \"train.tsv\")\n",
    "        self.dev_path = os.path.join(self.data_dir, \"dev.tsv\")\n",
    "        self.test_path = os.path.join(self.data_dir, \"test.tsv\")\n",
    "\n",
    "    def _read_tsv(cls, input_file, quotechar=None):\n",
    "        \"\"\"Reads a tab separated value file.\"\"\"\n",
    "        with tf.gfile.Open(input_file, \"r\") as f:\n",
    "            reader = csv.reader(f, delimiter=\"\\t\", quotechar=quotechar)\n",
    "            lines = []\n",
    "            for line in reader:\n",
    "                lines.append(line)\n",
    "            return lines\n",
    "\n",
    "    def _create_examples(self, lines, set_type):\n",
    "        \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
    "        examples = []\n",
    "        for (i, line) in enumerate(lines):\n",
    "            guid = \"%s-%s\" % (set_type, i)\n",
    "            label = tokenization.convert_to_unicode(line[0])\n",
    "            text = tokenization.convert_to_unicode(line[1])\n",
    "            tags = tokenization.convert_to_unicode(line[2])\n",
    "            examples.append(InputExample(guid=guid, text=text, tags=tags, label=label))\n",
    "        return examples\n",
    "\n",
    "    def get_train_examples(self):\n",
    "        return self._create_examples(self._read_tsv(self.train_path), \"train\")\n",
    "\n",
    "    def get_dev_examples(self):\n",
    "        return self._create_examples(self._read_tsv(self.dev_path), \"dev\")\n",
    "\n",
    "    def get_test_examples(self):\n",
    "        return self._create_examples(self._read_tsv(self.test_path), \"test\")\n",
    "\n",
    "    def get_labels_info(self):\n",
    "        labels = []\n",
    "        tags = []\n",
    "        label_map = {}\n",
    "        tags_map = {}\n",
    "        label_map_file = os.path.join(log_dir, \"label_map.txt\")\n",
    "        tags_map_file = os.path.join(log_dir, \"tags_map.txt\")\n",
    "        lines = self._read_tsv(self.train_path) + \\\n",
    "                self._read_tsv(self.dev_path) + \\\n",
    "                self._read_tsv(self.test_path)\n",
    "\n",
    "        for line in lines:\n",
    "            labels += line[0].strip().split()\n",
    "            tags += line[2].strip().split()\n",
    "            \n",
    "        tags.append(\"X\")\n",
    "        labels = sorted(set(labels), reverse=False)\n",
    "        tags = sorted(set(tags), reverse=False)\n",
    "        num_labels = sorted(set(labels), reverse=True).__len__()\n",
    "        num_tags = sorted(set(tags), reverse=True).__len__()\n",
    "\n",
    "        with tf.gfile.GFile(label_map_file, \"w\") as writer:\n",
    "            for (i, label) in enumerate(labels):\n",
    "                label_map[label] = i\n",
    "                writer.write(\"{}:{}\\n\".format(i, label))\n",
    "        \n",
    "        with tf.gfile.GFile(tags_map_file, \"w\") as writer:\n",
    "            for (i, tag) in enumerate(tags):\n",
    "                label_map[tag] = i\n",
    "                writer.write(\"{}:{}\\n\".format(i, tag))\n",
    "        return label_map, num_labels, tags_map, num_tags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T02:37:36.002705Z",
     "start_time": "2019-05-22T02:37:35.976317Z"
    }
   },
   "outputs": [],
   "source": [
    "def convert_single_example(ex_index, example, label_map, tags_map,\n",
    "                           max_seq_length, tokenizer):\n",
    "    tokens_list = example.text.split(\" \")\n",
    "    tags_list = example.tags.split(\" \")\n",
    "    tokens = []\n",
    "    tags = []\n",
    "    for i, (word, tag) in enumerate(zip(tokens_list, tags_list)):\n",
    "        token = tokenizer.tokenize(word)\n",
    "        tokens.extend(token)\n",
    "        for i, _ in enumerate(token):\n",
    "            if i == 0:\n",
    "                tags.append(tag)\n",
    "            else:\n",
    "                tags.append(\"X\")\n",
    "\n",
    "    # only Account for [CLS]  with \"- 1\".\n",
    "    if len(tokens) >= max_seq_length - 1:\n",
    "        tokens = tokens[0:(max_seq_length - 1)]\n",
    "        tags = labels[0:(max_seq_length - 1)]\n",
    "\n",
    "    tokens.insert(0, \"[CLS]\")\n",
    "    tags.insert(0, \"O\")\n",
    "    segment_ids = [0] * max_seq_length\n",
    "    tags_ids = [label_map[tag] for tag in tags]\n",
    "    input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    input_mask = [1] * len(input_ids)\n",
    "    label_id = label_map[example.label]\n",
    "\n",
    "    # Zero-pad up to the sequence length.\n",
    "    while len(input_ids) < max_seq_length:\n",
    "        input_ids.append(0)\n",
    "        input_mask.append(0)\n",
    "        tags_ids.append(tags_map[\"O\"])\n",
    "\n",
    "    assert len(input_ids) == max_seq_length\n",
    "    assert len(input_mask) == max_seq_length\n",
    "    assert len(segment_ids) == max_seq_length\n",
    "    assert len(tags_ids) == max_seq_length\n",
    "\n",
    "    if ex_index < 5:\n",
    "        tf.logging.info(\"*** Example ***\")\n",
    "        tf.logging.info(\"guid: %s\" % (example.guid))\n",
    "        tf.logging.info(\"label: %s\" % (example.label))\n",
    "        tf.logging.info(\"tokens: %s\" % \" \".join(\n",
    "            [tokenization.printable_text(x) for x in tokens]))\n",
    "        tf.logging.info(\n",
    "            \"input_ids: %s\" % \" \".join([str(x) for x in input_ids]))\n",
    "        tf.logging.info(\n",
    "            \"input_mask: %s\" % \" \".join([str(x) for x in input_mask]))\n",
    "        tf.logging.info(\n",
    "            \"segment_ids: %s\" % \" \".join([str(x) for x in segment_ids]))\n",
    "        tf.logging.info(\"tags_ids: %s\" % \" \".join([str(x) for x in tags_ids]))\n",
    "    feature = InputFeatures(\n",
    "        label_id=label_id,\n",
    "        input_ids=input_ids,\n",
    "        input_mask=input_mask,\n",
    "        segment_ids=segment_ids,\n",
    "        tags_ids=tags_ids)\n",
    "    return feature"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert_examples_to_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T02:34:40.781753Z",
     "start_time": "2019-05-22T02:34:40.766535Z"
    }
   },
   "outputs": [],
   "source": [
    "def file_based_convert_examples_to_features(\n",
    "        examples, label_map, tags_map, max_seq_length, tokenizer, output_file):\n",
    "    \"\"\"Convert a set of `InputExample`s to a TFRecord file.\"\"\"\n",
    "    writer = tf.python_io.TFRecordWriter(output_file)\n",
    "    for (ex_index, example) in enumerate(examples):\n",
    "        if ex_index % 2000 == 0:\n",
    "            tf.logging.info(\n",
    "                \"Writing example %d of %d\" % (ex_index, len(examples)))\n",
    "        feature = convert_single_example(ex_index, example, label_map,\n",
    "                                         tags_map, max_seq_length, tokenizer)\n",
    "\n",
    "        def create_int_feature(values):\n",
    "            f = tf.train.Feature(\n",
    "                int64_list=tf.train.Int64List(value=list(values)))\n",
    "            return f\n",
    "\n",
    "        features = collections.OrderedDict()\n",
    "        features[\"label_id\"] = create_int_feature(feature.label_id)\n",
    "        features[\"input_ids\"] = create_int_feature(feature.input_ids)\n",
    "        features[\"input_mask\"] = create_int_feature(feature.input_mask)\n",
    "        features[\"segment_ids\"] = create_int_feature(feature.segment_ids)\n",
    "        features[\"tags_ids\"] = create_int_feature(feature.tags_ids)\n",
    "        tf_example = tf.train.Example(\n",
    "            features=tf.train.Features(feature=features))\n",
    "        writer.write(tf_example.SerializeToString())\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T02:34:36.738879Z",
     "start_time": "2019-05-22T02:34:36.725009Z"
    }
   },
   "outputs": [],
   "source": [
    "def file_based_input_fn_builder(input_file, seq_length, is_training,\n",
    "                                drop_remainder, batch_size):\n",
    "    \"\"\"Creates an `input_fn` closure to be passed to Estimator.\"\"\"\n",
    "    name_to_features = {\n",
    "        \"label_id\": tf.FixedLenFeature([], tf.int64),\n",
    "        \"input_ids\": tf.FixedLenFeature([seq_length], tf.int64),\n",
    "        \"input_mask\": tf.FixedLenFeature([seq_length], tf.int64),\n",
    "        \"segment_ids\": tf.FixedLenFeature([seq_length], tf.int64),\n",
    "        \"tags_ids\": tf.FixedLenFeature([seq_length], tf.int64),\n",
    "    }\n",
    "\n",
    "    def _decode_record(record, name_to_features):\n",
    "        \"\"\"Decodes a record to a TensorFlow example.\"\"\"\n",
    "        return tf.parse_single_example(record, name_to_features)\n",
    "\n",
    "    def input_fn():\n",
    "        # For training, we want a lot of parallel reading and shuffling.\n",
    "        # For eval, we want no shuffling and parallel reading doesn't matter.\n",
    "        d = tf.data.TFRecordDataset(input_file)\n",
    "        if is_training:\n",
    "            d = d.repeat()\n",
    "            d = d.shuffle(buffer_size=100)\n",
    "        d = d.apply(\n",
    "            tf.data.experimental.map_and_batch(\n",
    "                lambda record: _decode_record(record, name_to_features),\n",
    "                batch_size=batch_size,\n",
    "                drop_remainder=drop_remainder))\n",
    "        return d\n",
    "\n",
    "    return input_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T02:34:32.605883Z",
     "start_time": "2019-05-22T02:34:32.590381Z"
    }
   },
   "outputs": [],
   "source": [
    "def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,\n",
    "                 labels, num_labels):\n",
    "    model = modeling.BertModel(\n",
    "        config=bert_config,\n",
    "        is_training=is_training,\n",
    "        input_ids=input_ids,\n",
    "        input_mask=input_mask,\n",
    "        token_type_ids=segment_ids)\n",
    "\n",
    "    # output_layer = model.get_pooled_output()\n",
    "    output_layer = model.get_sequence_output()\n",
    "\n",
    "    if is_training:\n",
    "        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)\n",
    "\n",
    "    logits = tf.layers.dense(\n",
    "        inputs=output_layer,\n",
    "        units=num_labels,\n",
    "        use_bias=True,\n",
    "        bias_initializer=tf.zeros_initializer(),\n",
    "        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))\n",
    "\n",
    "    mask_length = tf.reduce_sum(input_mask, axis=1)\n",
    "\n",
    "    with tf.variable_scope(\"crf_loss\"):\n",
    "        trans = tf.get_variable(\n",
    "            \"transition\",\n",
    "            shape=[num_labels, num_labels],\n",
    "            initializer=tf.contrib.layers.xavier_initializer())\n",
    "    log_likelihood, transition = tf.contrib.crf.crf_log_likelihood(\n",
    "        inputs=logits,\n",
    "        tag_indices=labels,\n",
    "        sequence_lengths=mask_length,\n",
    "        transition_params=trans)\n",
    "    loss = tf.reduce_mean(-log_likelihood)\n",
    "    decode_tags, best_score = tf.contrib.crf.crf_decode(\n",
    "        potentials=logits,\n",
    "        transition_params=transition,\n",
    "        sequence_length=mask_length)\n",
    "    return (loss, logits, decode_tags, mask_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T02:34:20.057265Z",
     "start_time": "2019-05-22T02:34:20.035282Z"
    }
   },
   "outputs": [],
   "source": [
    "def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,\n",
    "                     num_train_steps, num_warmup_steps):\n",
    "    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument\n",
    "\n",
    "        tf.logging.info(\"*** Features ***\")\n",
    "        for name in sorted(features.keys()):\n",
    "            tf.logging.info(\n",
    "                \"  name = %s, shape = %s\" % (name, features[name].shape))\n",
    "\n",
    "        input_ids = features[\"input_ids\"]\n",
    "        input_mask = features[\"input_mask\"]\n",
    "        segment_ids = features[\"segment_ids\"]\n",
    "        label_ids = features[\"label_ids\"]\n",
    "\n",
    "        is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
    "        (total_loss, logits, decode_tags, mask_length) = create_model(\n",
    "            bert_config, is_training, input_ids, input_mask, segment_ids,\n",
    "            label_ids, num_labels)\n",
    "        tvars = tf.trainable_variables()\n",
    "        initialized_variable_names = {}\n",
    "        if init_checkpoint:\n",
    "            (assignment_map, initialized_variable_names\n",
    "             ) = modeling.get_assignment_map_from_checkpoint(\n",
    "                 tvars, init_checkpoint)\n",
    "            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
    "\n",
    "\n",
    "        tf.logging.info(\"**** Trainable Variables ****\")\n",
    "        for var in tvars:\n",
    "            init_string = \"\"\n",
    "            if var.name in initialized_variable_names:\n",
    "                init_string = \", *INIT_FROM_CKPT*\"\n",
    "            tf.logging.info(\"  name = %s, shape = %s%s\", var.name, var.shape,\n",
    "                            init_string)\n",
    "\n",
    "        output_spec = None\n",
    "        if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "            train_op = optimization.create_optimizer(\n",
    "                total_loss, learning_rate, num_train_steps, num_warmup_steps)\n",
    "            output_spec = tf.estimator.EstimatorSpec(\n",
    "                mode=mode, loss=total_loss, train_op=train_op)\n",
    "        elif mode == tf.estimator.ModeKeys.EVAL:\n",
    "            accuracy = tf.metrics.accuracy(label_ids, decode_tags, input_mask)\n",
    "            evl_metrics = {\n",
    "                'accuracy': accuracy,\n",
    "            }\n",
    "            for metric_name, op in evl_metrics.items():\n",
    "                tf.summary.scalar(metric_name, op[1])\n",
    "            output_spec = tf.estimator.EstimatorSpec(\n",
    "                mode=mode,\n",
    "                loss=total_loss,\n",
    "                eval_metric_ops=evl_metrics)\n",
    "        else:\n",
    "            output_spec = tf.estimator.EstimatorSpec(\n",
    "                mode=mode,\n",
    "                predictions={\n",
    "                    \"predicted_ids\": decode_tags,\n",
    "                    \"label_ids\": label_ids,\n",
    "                    \"input_mask\": mask_length,\n",
    "                    \"input_ids\": input_ids\n",
    "                })\n",
    "        return output_spec\n",
    "\n",
    "    return model_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-22T02:37:40.926897Z",
     "start_time": "2019-05-22T02:37:40.638149Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using config: {'_model_dir': '../logs/joint_model', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': gpu_options {\n",
      "  allow_growth: true\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 10, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f09a4133550>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7f09a1de3158>) includes params argument, but params are not passed to Estimator.\n",
      "INFO:tensorflow:Writing example 0 of 4478\n"
     ]
    },
    {
     "ename": "KeyError",
     "evalue": "'O'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-26-a07c05ab4cbe>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     50\u001b[0m     \u001b[0mtrain_file\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"train.tf_record\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m     file_based_convert_examples_to_features(\n\u001b[0;32m---> 52\u001b[0;31m         train_examples, label_map, tags_map, max_seq_length, tokenizer, train_file)\n\u001b[0m\u001b[1;32m     53\u001b[0m     \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"***** Running training *****\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m     \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"  Num examples = %d\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_examples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-22-71bff200b4ff>\u001b[0m in \u001b[0;36mfile_based_convert_examples_to_features\u001b[0;34m(examples, label_map, tags_map, max_seq_length, tokenizer, output_file)\u001b[0m\n\u001b[1;32m      8\u001b[0m                 \"Writing example %d of %d\" % (ex_index, len(examples)))\n\u001b[1;32m      9\u001b[0m         feature = convert_single_example(ex_index, example, label_map,\n\u001b[0;32m---> 10\u001b[0;31m                                          tags_map, max_seq_length, tokenizer)\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcreate_int_feature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-25-7c4b95c90570>\u001b[0m in \u001b[0;36mconvert_single_example\u001b[0;34m(ex_index, example, label_map, tags_map, max_seq_length, tokenizer)\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0minput_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m         \u001b[0minput_mask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m         \u001b[0mtags_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtags_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"O\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m     \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mmax_seq_length\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: 'O'"
     ]
    }
   ],
   "source": [
    "tf.logging.set_verbosity(tf.logging.INFO)\n",
    "if not do_train and not do_eval and not do_predict:\n",
    "    raise ValueError(\n",
    "        \"At least one of `do_train`, `do_eval` or `do_predict' must be True.\")\n",
    "tf.gfile.MakeDirs(log_dir)\n",
    "processor = DataProcessor(data_dir)\n",
    "label_map, num_labels, tags_map, num_tags = processor.get_labels_info()\n",
    "tokenization.validate_case_matches_checkpoint(do_lower_case, init_checkpoint)\n",
    "bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
    "\n",
    "if max_seq_length > bert_config.max_position_embeddings:\n",
    "    raise ValueError(\"Cannot use sequence length %d because the BERT model \"\n",
    "                     \"was only trained up to sequence length %d\" %\n",
    "                     (max_seq_length, bert_config.max_position_embeddings))\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "    vocab_file=vocab_file, do_lower_case=do_lower_case)\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "\n",
    "run_config = tf.estimator.RunConfig(\n",
    "    model_dir=log_dir,\n",
    "    session_config=config,\n",
    "    save_checkpoints_steps=save_checkpoints_steps,\n",
    "    log_step_count_steps=log_step_count_steps,\n",
    "    save_summary_steps=save_summary_steps)\n",
    "\n",
    "train_examples = None\n",
    "num_train_steps = None\n",
    "num_warmup_steps = None\n",
    "\n",
    "if do_train:\n",
    "    train_examples = processor.get_train_examples()\n",
    "    num_train_steps = int(\n",
    "        len(train_examples) / train_batch_size * num_train_epochs)\n",
    "    num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "\n",
    "model_fn = model_fn_builder(\n",
    "    bert_config=bert_config,\n",
    "    num_labels=num_labels,\n",
    "    init_checkpoint=init_checkpoint,\n",
    "    learning_rate=learning_rate,\n",
    "    num_train_steps=num_train_steps,\n",
    "    num_warmup_steps=num_warmup_steps)\n",
    "\n",
    "estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)\n",
    "\n",
    "# Training\n",
    "if do_train:\n",
    "    train_file = os.path.join(log_dir, \"train.tf_record\")\n",
    "    file_based_convert_examples_to_features(\n",
    "        train_examples, label_map, tags_map, max_seq_length, tokenizer, train_file)\n",
    "    tf.logging.info(\"***** Running training *****\")\n",
    "    tf.logging.info(\"  Num examples = %d\", len(train_examples))\n",
    "    tf.logging.info(\"  Batch size = %d\", train_batch_size)\n",
    "    tf.logging.info(\"  Num steps = %d\", num_train_steps)\n",
    "    train_input_fn = file_based_input_fn_builder(\n",
    "        input_file=train_file,\n",
    "        seq_length=max_seq_length,\n",
    "        is_training=True,\n",
    "        drop_remainder=False,\n",
    "        batch_size=train_batch_size)\n",
    "    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-08T01:33:00.752206Z",
     "start_time": "2019-05-08T01:32:42.992Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if do_eval:\n",
    "    eval_examples = processor.get_dev_examples()\n",
    "    num_actual_eval_examples = len(eval_examples)\n",
    "    eval_file = os.path.join(log_dir, \"eval.tf_record\")\n",
    "    file_based_convert_examples_to_features(\n",
    "        eval_examples, label_map, max_seq_length, tokenizer, eval_file)\n",
    "    tf.logging.info(\"***** Running evaluation *****\")\n",
    "    tf.logging.info(\"  Num examples = %d (%d actual, %d padding)\",\n",
    "                    len(eval_examples), num_actual_eval_examples,\n",
    "                    len(eval_examples) - num_actual_eval_examples)\n",
    "    tf.logging.info(\"  Batch size = %d\", eval_batch_size)\n",
    "\n",
    "    eval_input_fn = file_based_input_fn_builder(\n",
    "        input_file=eval_file,\n",
    "        seq_length=max_seq_length,\n",
    "        is_training=False,\n",
    "        drop_remainder=False,\n",
    "        batch_size=eval_batch_size)\n",
    "\n",
    "    result = estimator.evaluate(input_fn=eval_input_fn)\n",
    "\n",
    "    output_eval_file = os.path.join(log_dir, \"eval_results.txt\")\n",
    "    with tf.gfile.GFile(output_eval_file, \"w\") as writer:\n",
    "        tf.logging.info(\"***** Eval results *****\")\n",
    "        for key in sorted(result.keys()):\n",
    "            tf.logging.info(\"  %s = %s\", key, str(result[key]))\n",
    "            writer.write(\"%s = %s\\n\" % (key, str(result[key])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-08T01:33:00.753301Z",
     "start_time": "2019-05-08T01:32:42.994Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if do_predict:\n",
    "\n",
    "    predict_examples = processor.get_test_examples()\n",
    "    num_actual_predict_examples = len(predict_examples)\n",
    "\n",
    "    predict_file = os.path.join(log_dir, \"predict.tf_record\")\n",
    "    if not tf.gfile.Exists(predict_file):\n",
    "        file_based_convert_examples_to_features(predict_examples, label_map,\n",
    "                                                max_seq_length, tokenizer,\n",
    "                                                predict_file)\n",
    "\n",
    "    tf.logging.info(\"***** Running prediction*****\")\n",
    "    tf.logging.info(\"  Num examples = %d (%d actual, %d padding)\",\n",
    "                    len(predict_examples), num_actual_predict_examples,\n",
    "                    len(predict_examples) - num_actual_predict_examples)\n",
    "    tf.logging.info(\"  Batch size = %d\", predict_batch_size)\n",
    "\n",
    "    predict_input_fn = file_based_input_fn_builder(\n",
    "        input_file=predict_file,\n",
    "        seq_length=max_seq_length,\n",
    "        is_training=False,\n",
    "        drop_remainder=False,\n",
    "        batch_size=predict_batch_size)\n",
    "\n",
    "    result = estimator.predict(input_fn=predict_input_fn)\n",
    "    label_map_new = {v: k for k, v in label_map.items()}\n",
    "\n",
    "    output_predict_file = os.path.join(log_dir, \"test_results.tsv\")\n",
    "    with tf.gfile.GFile(output_predict_file, \"w\") as writer:\n",
    "        true_list = []\n",
    "        predict_list = []\n",
    "        for item in result:\n",
    "            mask_length = item[\"input_mask\"]\n",
    "            label_ids = item[\"label_ids\"][:mask_length]\n",
    "            predicted_ids = item[\"predicted_ids\"][:mask_length]\n",
    "            input_ids = item[\"input_ids\"][:mask_length]\n",
    "            tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
    "\n",
    "            true_tags = [label_map_new[label] for label in label_ids]\n",
    "            pre_tags = [label_map_new[label] for label in predicted_ids]\n",
    "\n",
    "            for i in range(len(tokens)):\n",
    "                if tokens[i].startswith(\"[CLS]\") or tokens[i].startswith(\"##\"):\n",
    "                    pass\n",
    "                elif true_tags[i].startswith(\"X\") or pre_tags[i].startswith(\"X\"):\n",
    "                    pass               \n",
    "                else:\n",
    "                    writer.write(\"{} {} {}\\n\".format(tokens[i], true_tags[i],\n",
    "                                            pre_tags[i].strip()))\n",
    "            writer.write(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-08T01:33:00.754384Z",
     "start_time": "2019-05-08T01:32:42.997Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!python conlleval.py < '../logs/atis-slot/test_results.tsv'"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {
    "height": "143px",
    "width": "246px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "245.587px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "position": {
    "height": "436.725px",
    "left": "1105.45px",
    "right": "20px",
    "top": "109.988px",
    "width": "350px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
